"""
Sparse Autoencoder implementation in JAX.
A pure JAX implementation of a Sparse Autoencoder with no external dependencies other than JAX and optax.
"""
import functools
import jax
import jax.numpy as jnp
from jax import lax
import optax
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, NamedTuple

# Kaiming initialization
def kaiming_uniform_initializer(scale=2.0, mode="fan_in", distribution="truncated_normal"):
    """Build a He/Kaiming-style variance-scaling weight initializer.

    Despite the function name, the default distribution is a truncated
    normal; pass ``distribution="uniform"`` for the uniform variant.

    Args:
        scale: Numerator of the target variance (``variance = scale / fan``).
        mode: ``"fan_in"`` to divide by ``shape[0]`` or ``"fan_out"`` to
            divide by ``shape[1]``.
        distribution: ``"truncated_normal"`` or ``"uniform"``.

    Returns:
        A function ``init_fn(key, shape, dtype=jnp.float32)`` that returns an
        array of the requested shape. ``shape`` needs at least two entries.

    Raises:
        ValueError: Raised by ``init_fn`` when ``mode`` or ``distribution``
            is not one of the supported values.
    """
    def init_fn(key, shape, dtype=jnp.float32):
        fan_in, fan_out = shape[0], shape[1]
        # Reject unknown modes explicitly instead of silently treating them
        # as "fan_out".
        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        else:
            raise ValueError(f"Unsupported mode: {mode}")
        variance = scale / denominator

        if distribution == "truncated_normal":
            # Standard normal truncated to [-2, 2], then scaled by
            # sqrt(variance); no correction is applied for the truncation.
            stddev = jnp.sqrt(variance)
            return jax.random.truncated_normal(key, -2, 2, shape, dtype) * stddev
        elif distribution == "uniform":
            # U(-bound, bound) has variance bound**2 / 3, hence the factor 3.
            bound = jnp.sqrt(3.0 * variance)
            return jax.random.uniform(key, shape, dtype, -bound, bound)
        else:
            raise ValueError(f"Unsupported distribution: {distribution}")

    return init_fn

class TopKActivation(NamedTuple):
    """Keep the ``k`` largest-magnitude entries along the last axis; zero the rest."""
    k: int

    def apply(self, params, x):
        """Apply the top-k selection to ``x``.

        Args:
            params: Accepted for interface symmetry with other layers; unused.
            x: Array whose last axis is ranked by absolute value.

        Returns:
            tuple: ``(activated, mask)`` where ``mask`` is True wherever
            ``|x|`` is at least the k-th largest magnitude of its row (so
            ties with that value are all kept) and ``activated`` is ``x``
            with every unmasked entry replaced by 0.
        """
        magnitudes = jnp.abs(x)
        # Negating before and after an ascending sort gives a descending sort.
        descending = -jnp.sort(-magnitudes, axis=-1)
        threshold = descending[..., self.k - 1][..., None]
        keep = magnitudes >= threshold
        return jnp.where(keep, x, 0), keep

class Linear(NamedTuple):
    """Affine layer ``x @ weights + bias`` whose parameters live in a dict."""
    params: Dict

    @classmethod
    def init(cls, key, input_dim, output_dim, initializer=None):
        """Create a layer with freshly initialized parameters.

        Args:
            key: JAX PRNG key.
            input_dim: Number of input features (rows of ``weights``).
            output_dim: Number of output features (columns of ``weights``).
            initializer: Callable ``(key, shape) -> array`` for the weights;
                defaults to ``kaiming_uniform_initializer()``.

        Returns:
            Linear: Instance whose ``params`` holds ``"weights"`` of shape
            ``(input_dim, output_dim)`` and a zero ``"bias"`` of shape
            ``(output_dim,)``.
        """
        make_weights = kaiming_uniform_initializer() if initializer is None else initializer

        # The key is split in two, but only the first half is consumed: the
        # bias is deterministic (all zeros).
        w_key, _unused_bias_key = jax.random.split(key)

        return cls(
            params={
                "weights": make_weights(w_key, (input_dim, output_dim)),
                "bias": jnp.zeros((output_dim,)),
            }
        )

    def apply(self, params, x):
        """Return ``x @ params["weights"] + params["bias"]``.

        The explicitly passed ``params`` are used, not ``self.params``.
        """
        projected = jnp.dot(x, params["weights"])
        return projected + params["bias"]

class SparseAutoencoder(NamedTuple):
    """Top-k sparse autoencoder: tied input/output bias, linear encoder and decoder."""
    params: Dict
    k: int
    embed_dim: int
    hidden_dim: int

    @classmethod
    def init(cls, key, embed_dim, hidden_dim, k, bias_init=0.0):
        """Create a model with freshly initialized parameters.

        Args:
            key: JAX PRNG key; split three ways, of which two parts are used.
            embed_dim: Width of the inputs and reconstructions.
            hidden_dim: Number of latents.
            k: Number of latents kept by the top-k activation.
            bias_init: Fill value for the tied bias vector.

        Returns:
            SparseAutoencoder: Model whose ``params`` dict has the keys
            ``"encoder"``, ``"decoder"`` and ``"tied_bias"``.
        """
        # The third sub-key is never consumed: the tied bias is a constant fill.
        enc_key, dec_key, _unused_key = jax.random.split(key, 3)

        return cls(
            params={
                "encoder": Linear.init(enc_key, embed_dim, hidden_dim).params,
                "decoder": Linear.init(dec_key, hidden_dim, embed_dim).params,
                "tied_bias": jnp.full((embed_dim,), bias_init),
            },
            k=k,
            embed_dim=embed_dim,
            hidden_dim=hidden_dim,
        )

    def apply(self, params, x, return_intermediates=False):
        """Encode ``x``, keep the top-k latents, and decode.

        Args:
            params: Parameter dict laid out as produced by ``init``.
            x: Input array; its last axis is contracted with the encoder
                weights.
            return_intermediates: When True, also return the latents before
                and after the top-k selection together with the mask.

        Returns:
            The reconstruction, or ``(reconstruction, intermediates)`` when
            ``return_intermediates`` is True.
        """
        enc, dec, shift = params["encoder"], params["decoder"], params["tied_bias"]

        # The tied bias is subtracted before encoding and added back after
        # decoding.
        pre_act = jnp.dot(x - shift, enc["weights"]) + enc["bias"]
        sparse, mask = TopKActivation(k=self.k).apply(None, pre_act)
        reconstruction = jnp.dot(sparse, dec["weights"]) + dec["bias"] + shift

        if not return_intermediates:
            return reconstruction
        return reconstruction, {
            "pre_activation": pre_act,
            "post_activation": sparse,
            "topk_mask": mask
        }

class InactiveLatentTracker:
    """
    Tracks, per latent, how many consecutive tokens have passed without it firing.

    A latent counts as inactive once that streak reaches ``inactive_threshold``.
    """
    def __init__(self, num_latents, inactive_threshold):
        self.num_latents = num_latents
        self.inactive_threshold = inactive_threshold
        # Tokens seen since each latent was last active.
        self.zero_counts = jnp.zeros(num_latents, dtype=jnp.int32)
        self.total_tokens_seen = 0

    @classmethod
    def create(cls, num_latents, inactive_threshold):
        """Create a new tracker instance."""
        return cls(num_latents, inactive_threshold)

    def update(self, top_active_mask):
        """Return a new tracker that accounts for one more batch.

        Args:
            top_active_mask: Boolean activation mask; axis 0 is treated as
                the token axis and is reduced away, so the remainder must
                line up with ``zero_counts``.

        Returns:
            InactiveLatentTracker: A fresh instance; ``self`` is not mutated.
        """
        num_tokens = top_active_mask.shape[0]
        active_any = jnp.any(top_active_mask, axis=0)

        tracker = InactiveLatentTracker(self.num_latents, self.inactive_threshold)
        # A latent that fired anywhere in the batch restarts its streak;
        # every other streak grows by the number of tokens in the batch.
        tracker.zero_counts = jnp.where(active_any, 0, self.zero_counts + num_tokens)
        tracker.total_tokens_seen = self.total_tokens_seen + num_tokens
        return tracker

    def get_inactive_latents_mask(self):
        """Return a boolean mask of latents whose streak has reached the threshold."""
        return self.zero_counts >= self.inactive_threshold

    def get_top_k_inactive_latents(self, k_aux):
        """Return indices of the longest-inactive latents, longest first.

        Args:
            k_aux: Maximum number of indices to return.

        Returns:
            ``None`` when no latent is inactive; otherwise an index array of
            length ``min(k_aux, number of inactive latents)``. Only inactive
            latents are ever returned.
        """
        inactive_mask = self.get_inactive_latents_mask()
        num_inactive = int(jnp.sum(inactive_mask))
        if num_inactive == 0:
            return None

        # Active latents get -inf so that they sort after every inactive one.
        inactive_counts = jnp.where(inactive_mask, self.zero_counts, -jnp.inf)
        sorted_indices = jnp.argsort(inactive_counts)[::-1]
        # Cap at the number of inactive latents; otherwise the tail of the
        # slice would be filled with active latents (the -inf entries).
        # NOTE(review): the result length now varies between calls; a caller
        # that jit-compiles on this array will presumably retrace for each
        # distinct length - confirm that is acceptable.
        return sorted_indices[:min(k_aux, num_inactive)]

@jax.jit
def jit_safe_top_k(x, k):
    """Select the ``k`` largest entries along the last axis without a static ``k``.

    Returns:
        tuple: ``(values, mask)`` where ``mask`` is True for entries that are
        at least the k-th largest of their row (ties included) and ``values``
        is ``x`` with every other entry replaced by ``-inf``.
    """
    # Descending order via a double negation around the ascending sort.
    descending = -jnp.sort(-x, axis=-1)
    # k-th largest value per row, kept broadcastable against x.
    cutoff = descending[..., k - 1][..., None]
    selected = x >= cutoff
    return jnp.where(selected, x, -jnp.inf), selected

def loss_fn(
    params,
    inputs,
    targets,
    k,
    embed_dim,
    hidden_dim,
    top_k_inactive=None,
):
    """
    Compute the loss for the SparseAutoencoder.

    The forward pass (subtract tied bias, encode, top-k by magnitude, decode,
    add tied bias) is written out inline rather than delegated to a model
    object.

    Args:
        params: Model parameters with keys "encoder", "decoder", "tied_bias"
        inputs: Input data; ``inputs.shape[1]`` is taken as the model width
        targets: Target data (usually same as input for autoencoders)
        k: Number of active units
        embed_dim: Embedding dimension (unused here; kept for the interface)
        hidden_dim: Hidden dimension (unused here; kept for the interface)
        top_k_inactive: Indices of top inactive latents (optional)

    Returns:
        tuple: (total_loss, metrics). ``metrics`` also carries the
        pre-selection latents and the top-k mask so the caller can track
        latent activity.
    """
    # Apply the negative of the tied bias
    x_minus_bias = inputs - params["tied_bias"]

    # Encode
    encoded = jnp.dot(x_minus_bias, params["encoder"]["weights"]) + params["encoder"]["bias"]

    # Apply TopK activation: keep entries whose magnitude reaches the k-th largest
    sorted_x = -jnp.sort(-jnp.abs(encoded), axis=-1)
    k_th_largest = jnp.expand_dims(sorted_x[..., k - 1], axis=-1)
    topk_mask = jnp.abs(encoded) >= k_th_largest
    activated = jnp.where(topk_mask, encoded, 0)

    # Decode
    decoded = jnp.dot(activated, params["decoder"]["weights"]) + params["decoder"]["bias"]

    # Add tied bias back
    model_out = decoded + params["tied_bias"]

    # Save intermediates
    pre_selection_latents = encoded
    top_active_mask = topk_mask

    # Reconstruction error: squared error summed over features, averaged over the batch
    e = model_out - targets
    squared_error = jnp.square(e)
    reconstruction_error = jnp.mean(jnp.sum(squared_error, axis=-1))

    # Total sum of squares of the targets around their per-feature batch mean
    total_variance = jnp.sum(jnp.square(targets - jnp.mean(targets, axis=0)))

    # FVU (Fraction of Variance Unexplained) = SS_res / SS_tot. Both terms are
    # summed over the whole batch so the ratio does not depend on batch size.
    fvu = jnp.sum(squared_error) / (total_variance + 1e-8)

    # Calculate auxiliary loss if inactive latents are provided
    if top_k_inactive is not None and len(top_k_inactive) > 0:
        dead_mask = jnp.zeros(pre_selection_latents.shape[1], dtype=bool).at[top_k_inactive].set(True)

        # Set active latents to -inf, keep dead latents as is
        dead_latents = jnp.where(dead_mask, pre_selection_latents, -jnp.inf)

        num_inactive = jnp.sum(dead_mask)

        # Heuristic from Appendix B.1 in the paper; floored at 1 so a width-1
        # model neither divides by zero nor asks for a top-0 selection.
        d_model = inputs.shape[1]
        k_aux_target = max(d_model // 2, 1)

        # Down-weight the auxiliary loss when fewer than k_aux_target latents
        # are dead. This must use the unclamped target: against the clamped
        # k_aux below the ratio is always >= 1 and the scale would be a no-op.
        scale = jnp.minimum(num_inactive / k_aux_target, 1.0)

        # Ensure k_aux is not larger than number of dead latents
        k_aux = jnp.minimum(k_aux_target, num_inactive)

        top_dead_values, top_dead_mask = jit_safe_top_k(dead_latents, k_aux)

        # Create sparse representation of top dead latents
        sparse_dead_latents = jnp.where(top_dead_mask, top_dead_values, 0)

        # Reconstruct using only the decoder part
        dead_decoded = jnp.dot(sparse_dead_latents, params["decoder"]["weights"]) + params["decoder"]["bias"]
        dead_reconstruction = dead_decoded + params["tied_bias"]

        # Calculate dead latent reconstruction error
        e_hat = dead_reconstruction - targets

        aux_loss = jnp.mean(jnp.square(e - e_hat))
        aux_loss = aux_loss * scale / (total_variance + 1e-8)
    else:
        aux_loss = 0.0

    alpha = 1/32
    aux_loss *= alpha

    # Sanitize NaNs: a NaN reconstruction error is reported as +inf, while a
    # NaN auxiliary loss is dropped to 0. Infinities pass through unchanged.
    reconstruction_error = jnp.nan_to_num(reconstruction_error, nan=jnp.inf, posinf=jnp.inf, neginf=-jnp.inf)
    aux_loss = jnp.nan_to_num(aux_loss, nan=0.0, posinf=jnp.inf, neginf=-jnp.inf)

    total_loss = reconstruction_error + aux_loss

    metrics = {
        "pre_selection_latents": pre_selection_latents,
        "top_active_mask": top_active_mask,
        "Total Loss": total_loss,
        "Reconstruction Error": reconstruction_error,
        "Auxiliary Loss": aux_loss,
        "FVU": fvu,
    }

    return total_loss, metrics

def data_generator(data, batch_size, key=None):
    """
    Generate shuffled batches of data.

    Args:
        data (jax.Array): Full dataset; batches are taken along axis 0.
        batch_size (int): Size of each batch. The last batch is smaller when
            the number of samples is not a multiple of ``batch_size``.
        key (jax.Array, optional): PRNG key used for the shuffle. Defaults to
            ``jax.random.PRNGKey(0)``, which gives the same order on every
            call; pass a different key per epoch to reshuffle.

    Yields:
        jax.Array: A batch of data.
    """
    if key is None:
        key = jax.random.PRNGKey(0)

    num_samples = data.shape[0]
    indices = jax.random.permutation(key, num_samples)
    for start in range(0, num_samples, batch_size):
        # Slicing clamps at the end, so the final batch may be short.
        yield data[indices[start:start + batch_size]]

def tensorstore_data_generator(dataset, idxs, batch_size):
    """
    Generate batches of data from a tensorstore database.

    Rows are loaded in chunks of ``batch_size * 100`` indices. The read for
    the next chunk is started before the current chunk is yielded, so it can
    proceed while the consumer works through the current batches.

    Args:
        dataset: TensorStore dataset; ``dataset[outer, inner].read()`` must
            return an object with a ``result()`` method.
        idxs (jax.Array): Index pairs of shape ``(N, 2)``; column 0 and
            column 1 are the two indices passed to ``dataset``.
        batch_size (int): Size of each batch.

    Yields:
        jax.Array: A batch of data. Every index in ``idxs`` is covered
        exactly once, in order; only the very last batch may be short.
    """
    chunk_size = batch_size * 100

    def _start_read(start):
        # Kick off the read for one chunk of index pairs.
        to_load = idxs[start:start + chunk_size]
        outer, inner = to_load[:, 0], to_load[:, 1]
        return dataset[outer, inner].read()

    def _batches(table):
        # Step by the table's real length so a short final chunk yields no
        # empty batches.
        for offset in range(0, len(table), batch_size):
            yield jnp.array(table[offset:offset + batch_size])

    chunk_starts = range(0, len(idxs), chunk_size)
    if len(chunk_starts) == 0:
        return

    pending = _start_read(chunk_starts[0])
    for next_start in chunk_starts[1:]:
        current = pending.result()
        # Start the next read before handing out the current chunk.
        pending = _start_read(next_start)
        yield from _batches(current)

    # The last chunk has no successor to prefetch; drain it here.
    yield from _batches(pending.result())

def train_step(params, opt_state, batch, top_k_inactive, k, embed_dim, hidden_dim, optimizer):
    """Run one optimization step on ``batch``.

    The batch serves as both input and reconstruction target of ``loss_fn``.

    Returns:
        tuple: ``(new_params, new_opt_state, loss, metrics)``.
    """
    # Differentiate with respect to the parameters only; everything else is
    # closed over. has_aux=True carries the metrics dict alongside the loss.
    grad_fn = jax.value_and_grad(
        lambda p: loss_fn(p, batch, batch, k, embed_dim, hidden_dim, top_k_inactive),
        has_aux=True,
    )
    (loss, metrics), grads = grad_fn(params)

    updates, next_opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), next_opt_state, loss, metrics

# JIT-compiled entry point for train_step. The sizes and the optimizer are
# declared static, so they are treated as compile-time constants, not traced.
@functools.partial(
    jax.jit,
    static_argnames=('k', 'embed_dim', 'hidden_dim', 'optimizer'),
)
def jitted_train_step(params, opt_state, batch, top_k_inactive, k, embed_dim, hidden_dim, optimizer):
    """Compiled wrapper that forwards every argument unchanged to ``train_step``."""
    return train_step(
        params=params,
        opt_state=opt_state,
        batch=batch,
        top_k_inactive=top_k_inactive,
        k=k,
        embed_dim=embed_dim,
        hidden_dim=hidden_dim,
        optimizer=optimizer,
    )

def train(
    model,
    inputs_array,
    num_epochs,
    batch_size=4096,
    learning_rate=1e-3,
    log_every=100,
    data_gen=data_generator,
    inactive_threshold=10_000
):
    """
    Train the SparseAutoencoder with Adam and periodic console logging.

    Args:
        model (SparseAutoencoder): The model to train; supplies the initial
            parameters and the sizes ``k``, ``embed_dim``, ``hidden_dim``
        inputs_array (jax.Array): Input data, handed to ``data_gen``
        num_epochs (int): Number of passes over ``data_gen``'s output
        batch_size (int): Batch size handed to ``data_gen``
        learning_rate (float): Adam learning rate
        log_every (int): Print metrics every n steps
        data_gen (function): Called as ``data_gen(inputs_array, batch_size)``
            once per epoch; must yield batches
        inactive_threshold (int): Tokens without firing after which a latent
            is treated as inactive

    Returns:
        SparseAutoencoder: A new model holding the trained parameters
    """
    # Static model configuration and starting parameters
    k, embed_dim, hidden_dim = model.k, model.embed_dim, model.hidden_dim
    params = model.params

    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(params)

    # Heuristic from Appendix B.1 in the paper
    k_aux = embed_dim // 2

    tracker = InactiveLatentTracker.create(hidden_dim, inactive_threshold=inactive_threshold)

    nan_loss_count = 0
    inf_loss_count = 0
    step = 0

    for epoch in range(num_epochs):
        for batch in data_gen(inputs_array, batch_size):
            # Inactive latents are chosen from the state before this step.
            top_k_inactive = tracker.get_top_k_inactive_latents(k_aux)

            params, opt_state, _, metrics = jitted_train_step(
                params,
                opt_state,
                batch,
                top_k_inactive,
                k,
                embed_dim,
                hidden_dim,
                optimizer
            )

            # A latent is "revived" when it was inactive before this batch
            # and is no longer inactive after it.
            was_inactive = tracker.get_inactive_latents_mask()
            tracker = tracker.update(metrics["top_active_mask"])
            currently_inactive_mask = tracker.get_inactive_latents_mask()
            revived_latents = jnp.sum(jnp.logical_and(was_inactive, ~currently_inactive_mask))

            # Count non-finite losses
            total_loss = metrics["Total Loss"]
            if jnp.isnan(total_loss):
                nan_loss_count += 1
            if jnp.isinf(total_loss):
                inf_loss_count += 1

            step += 1
            if step % log_every != 0:
                continue

            log_metrics = {
                "epoch": epoch + 1,
                "step": step,
                "total_loss": total_loss,
                "reconstruction_error": metrics["Reconstruction Error"],
                "auxiliary_loss": metrics["Auxiliary Loss"],
                "fvu": metrics["FVU"],
                "inactive_latents": jnp.sum(currently_inactive_mask),
                "revived_latents": revived_latents,
                "nan_loss_count": nan_loss_count,
                "inf_loss_count": inf_loss_count,
            }

            print(f"Epoch {epoch+1}, Step {step}: {log_metrics}")

            # Debug prints
            print(f"Pre-selection latents stats: min={metrics['pre_selection_latents'].min()}, "
                  f"max={metrics['pre_selection_latents'].max()}, "
                  f"mean={metrics['pre_selection_latents'].mean()}")
            print(f"Top-k mask sum: {metrics['top_active_mask'].sum()}")

    # Package the trained parameters with the unchanged configuration.
    return SparseAutoencoder(
        params=params,
        k=k,
        embed_dim=embed_dim,
        hidden_dim=hidden_dim
    )

# Sparse representations and reconstructions
@jax.jit
def _sae_encode_decode(params, batch, k):
    """Run the forward pass on one batch; returns ``(sparse_codes, reconstruction)``."""
    # Apply the negative of the tied bias
    x_minus_bias = batch - params["tied_bias"]

    # Encode
    encoded = jnp.dot(x_minus_bias, params["encoder"]["weights"]) + params["encoder"]["bias"]

    # Apply TopK activation: keep entries whose magnitude reaches the k-th largest
    sorted_x = -jnp.sort(-jnp.abs(encoded), axis=-1)
    k_th_largest = jnp.expand_dims(sorted_x[..., k - 1], axis=-1)
    topk_mask = jnp.abs(encoded) >= k_th_largest
    activated = jnp.where(topk_mask, encoded, 0)

    # Decode
    decoded = jnp.dot(activated, params["decoder"]["weights"]) + params["decoder"]["bias"]

    # Add tied bias back
    output = decoded + params["tied_bias"]

    return activated, output


def get_sparse_representations(model, inputs, batch_size=1024):
    """
    Get sparse representations and reconstructions for the input data.

    Args:
        model: SparseAutoencoder model (only ``model.params`` and ``model.k``
            are read)
        inputs: Input data; processed in slices along axis 0
        batch_size: Batch size for processing

    Returns:
        Tuple of (sparse_codes, reconstructions), each concatenated along
        axis 0 in input order. Empty input yields two empty arrays.
    """
    num_samples = inputs.shape[0]

    # With no rows the loop below would collect nothing and concatenate would
    # raise; a single pass over the empty array yields correctly shaped
    # empty outputs instead.
    if num_samples == 0:
        return _sae_encode_decode(model.params, inputs, model.k)

    codes = []
    reconstructions = []

    for start in range(0, num_samples, batch_size):
        batch = inputs[start:start + batch_size]
        batch_code, batch_reconstruction = _sae_encode_decode(model.params, batch, model.k)
        codes.append(batch_code)
        reconstructions.append(batch_reconstruction)
        # Progress report after every tenth batch
        if (start + batch_size) % (10 * batch_size) == 0:
            print(f"Processed {min(start + batch_size, num_samples)}/{num_samples} samples")

    return jnp.concatenate(codes), jnp.concatenate(reconstructions)

# Dictionary (decoder weight) analysis
def analyze_dictionary(model):
    """
    Analyze the dictionary (decoder weights) of the model.

    Each row of the decoder weight matrix is treated as one dictionary
    vector. Rows are scaled to unit length and compared pairwise by cosine
    similarity.

    Args:
        model: SparseAutoencoder model (only
            ``model.params["decoder"]["weights"]`` is read)

    Returns:
        Dict with:
            "normalized_dictionary": row-normalized weights; all-zero rows
                stay all-zero
            "cosine_similarities": full pairwise similarity matrix, diagonal
                included
            "average_similarity": mean absolute similarity over all entries
                of the matrix, with the diagonal counted as 0
            "maximum_similarity": largest absolute off-diagonal similarity
    """
    # Get decoder weights
    D = model.params["decoder"]["weights"]

    # Normalize rows. A zero-norm row would otherwise divide by zero and
    # spread NaN through every statistic below; dividing it by 1 leaves it
    # as zeros.
    row_norms = jnp.linalg.norm(D, axis=1, keepdims=True)
    safe_norms = jnp.where(row_norms > 0, row_norms, 1.0)
    D_normalized = D / safe_norms

    # Compute pairwise cosine similarities
    cosine_sims = jnp.dot(D_normalized, D_normalized.T)

    # Zero the self-similarities so they do not dominate the statistics
    cosine_sims_no_diag = cosine_sims.at[jnp.diag_indices_from(cosine_sims)].set(0)
    abs_sims = jnp.abs(cosine_sims_no_diag)

    return {
        "normalized_dictionary": D_normalized,
        "cosine_similarities": cosine_sims,
        "average_similarity": jnp.mean(abs_sims),
        "maximum_similarity": jnp.max(abs_sims)
    }